"""This is the file that does expectations with foresight, discounting, etc."""

import os

import matplotlib.pyplot as plt
import numpy as np
from joblib import Parallel, delayed
from numba import njit
from scipy.optimize import minimize

np.set_printoptions(precision=3)

@njit
def output(k, static_inputs, nu, rho, om_E):
    """Return CES output for capital ``k`` and two static inputs.

    Computes ``(k**rho + x0**rho + (om_E*x1)**rho)**(nu/rho)`` where
    ``x0, x1 = static_inputs[0], static_inputs[1]``. Callers in this file
    pass labor as ``x0`` and energy as ``x1``, so ``om_E`` scales the
    energy input.
    """
    first_term = static_inputs[0]**rho
    scaled_term = (om_E*static_inputs[1])**rho
    # Same left-to-right summation as a single expression would use.
    aggregate = k**rho + first_term + scaled_term
    return aggregate**(nu/rho)

def static_profit(k, inputs, nu, rho, om_E, prices):
    """Return one-period profit, net of static input costs but not of capital.

    ``prices[0]`` is the output price and ``prices[1:]`` are the prices of
    the entries of ``inputs`` (same order).
    """
    revenue = prices[0]*output(k, np.array(inputs), nu, rho, om_E)
    input_cost = np.dot(inputs, prices[1:])
    return revenue - input_cost

def incumbent_profit(k0, prices, om_E, g0, g1, g2, nu, rho, beta, profit_only=False):
    """Pick the best of adjusting capital, keeping capital, or exiting.

    Static profit is capitalised with ``1/(1-beta)`` (a plant that stays
    in stays in forever, with no depreciation). Three options are compared:

    * adjust: choose capital and both static inputs, paying
      ``g0 + g1*(k-k0) + g2*(k-k0)**2``;
    * keep: hold capital at ``k0`` and choose only the static inputs;
    * exit: sell the capital stock for ``g1*k0``.

    Exit wins ties; adjusting must strictly beat keeping capital fixed.

    Returns the winning profit alone when ``profit_only`` is true.
    Otherwise returns ``(row, opt)`` where ``row`` is
    ``[om_E, k1, l1, e1, Y, e1/Y, profit]`` and ``opt`` is ``-1`` (exit),
    ``1`` (adjusted) or ``2`` (unadjusted). On exit the row holds zeros for
    ``k1, l1, e1, Y`` and ``-1`` in the intensity slot.
    """
    horizon = 1/(1-beta)  # Present value of a unit of profit earned forever
    box = (1e-6, 1e6)

    def neg_value_adjust(x):
        # Negative value when capital moves from k0 to x[0]
        dk = x[0] - k0
        return -(horizon*static_profit(x[0], x[1:], nu, rho, om_E, prices) - g0 - g1*dk - g2*dk**2)

    def neg_value_keep(x):
        # Negative value when capital stays at k0
        return -(horizon*static_profit(k0, x, nu, rho, om_E, prices))

    res_adj = minimize(neg_value_adjust, [k0, 100, 100], bounds=[box, box, box])
    res_fix = minimize(neg_value_keep, [100, 100], bounds=[box, box])

    value_adj = -res_adj.fun
    value_fix = -res_fix.fun
    value_exit = g1*k0  # Sell all capital and shut down

    if (value_exit >= value_fix) and (value_exit >= value_adj):
        opt = -1  # Code for exit
        profit = value_exit
        k1, l1, e1, Y, e1_Y = 0, 0, 0, 0, -1
    elif value_adj > value_fix:
        opt = 1  # Code for adjusted
        profit = value_adj
        k1, l1, e1 = res_adj.x
        Y = output(k1, np.array([l1, e1]), nu, rho, om_E)
        e1_Y = e1/Y
    else:
        opt = 2  # Code for unadjusted
        profit = value_fix
        k1 = k0
        l1, e1 = res_fix.x
        Y = output(k1, np.array([l1, e1]), nu, rho, om_E)
        e1_Y = e1/Y

    if profit_only:
        return profit

    return np.array([om_E, k1, l1, e1, Y, e1_Y, profit]), opt

def expected_inc_profit(k0, prices1, dist_p1, om_E, g0, g1, g2, nu, rho, beta):
    """Return the probability-weighted incumbent profit over price scenarios.

    ``prices1[i]`` is the price vector in scenario ``i`` and ``dist_p1[i]``
    the weight placed on it.
    """
    expected = 0
    for scenario, scenario_prices in enumerate(prices1):
        payoff = incumbent_profit(k0, scenario_prices, om_E, g0, g1, g2, nu, rho, beta, profit_only=True)
        expected += dist_p1[scenario]*payoff
    return expected

def entrant_profit(prices0, om_E, g0, g1, g2, nu, rho, beta, prices1, dist_p1):
    """Solve the entrant's problem and report its optimal plan.

    The entrant chooses capital and the two static inputs to maximise
    current static profit at ``prices0``, minus ``g1`` per unit of capital
    bought, plus ``beta`` times the expected incumbent profit next period.

    Returns ``[om_E, k0, l0, e0, Y, e0/Y, value]`` where ``value`` is the
    maximised objective.
    """
    def neg_entry_value(x):
        capital = x[0]
        today = static_profit(capital, x[1:], nu, rho, om_E, prices0) - g1*capital
        continuation = expected_inc_profit(capital, prices1, dist_p1, om_E, g0, g1, g2, nu, rho, beta)
        return -(today + beta*continuation)

    box = (1e-6, 1e6)
    res = minimize(neg_entry_value, [10, 10, 10], bounds=[box, box, box])

    k0, l0, e0 = res.x
    Y = output(k0, np.array([l0, e0]), nu, rho, om_E)
    intensity = e0/Y
    return np.array([om_E, k0, l0, e0, Y, intensity, -res.fun])

def industry_entrants(prices0, g0, g1, g2, nu, rho, beta, FC, prices1, dist_p1, om_range=(0.5, 8), om_res=250, n_jobs=6):
    """Solve the entry problem over a grid of energy productivities.

    Parameters
    ----------
    prices0 : current price vector passed to ``entrant_profit``.
    g0, g1, g2, nu, rho, beta : model parameters forwarded unchanged.
    FC : fixed cost of entry; a plant enters only if its value exceeds it.
    prices1, dist_p1 : next-period price scenarios and their weights.
    om_range : (low, high) bounds of the ``om_E`` grid. The default is a
        tuple rather than a list so it cannot be mutated between calls.
    om_res : number of grid points.
    n_jobs : number of joblib workers (previously hard-coded to 6).

    Returns
    -------
    totals : (om_res, 7) array, one ``entrant_profit`` row per grid point.
    entrants : boolean mask of grid points whose value exceeds ``FC`` and
        whose column 3 (energy input) is positive.
    """
    om_Es = np.linspace(om_range[0], om_range[1], om_res)
    results = Parallel(n_jobs=n_jobs)(
        delayed(entrant_profit)(prices0, om_E, g0, g1, g2, nu, rho, beta, prices1, dist_p1)
        for om_E in om_Es
    )

    # Unpack into a preallocated array so the shape is (om_res, 7) even
    # when the grid is empty.
    totals = np.empty((om_res, 7))
    for i, row in enumerate(results):
        totals[i, :] = row

    entrants = (totals[:, -1] > FC) & (totals[:, 3] > 0)  # Profit exceeds fixed costs, and is operating
    print(np.sum(entrants), '/', om_res, 'entrants have profit exceeding fixed costs')

    return totals, entrants


def price_change_curve(g0, g1, g2, nu, rho, beta, FC, dist_p1, p_res, om_res=250):
    """Average energy intensity after a price change, over a grid of initial energy prices.

    For each of ``p_res`` period-0 energy prices in [0.25, 1.25], plants
    enter given beliefs ``dist_p1`` over next-period prices, then face the
    period-1 prices (output 3, labor 1, energy 1) as incumbents under three
    adjustment-cost regimes.

    Returns
    -------
    avg_intensity : (p_res, 4) array. Columns 0-2 hold the mean energy/output
        ratio of surviving incumbents with prohibitive (``g0 = g2 = 1e9``),
        actual (``g0, g2``) and zero adjustment costs. Column 3 holds the
        mean for fresh entrants who face the period-1 prices with
        certainty. A mean over no plants is ``nan``.
    energy_prices : the period-0 energy price grid.
    """
    # Period-1 prices are identical for every grid point, so build them once
    # instead of reading them back through the loop index after the loop
    # (which raised NameError when p_res == 0).
    prices_new = np.array([3.0, 1.0, 1.0])  # Price of output, labor, energy

    # P1_range[i, t, :] is the price vector at grid point i in period t
    P1_range = np.empty((p_res, 2, 3))
    P1_range[:] = prices_new
    P1_range[:, 0, -1] = np.linspace(0.25, 1.25, p_res)  # Period-0 energy price

    avg_intensity = np.zeros((p_res, 4))  # Fixed, friction, and flexible incumbents, new entrants
    for i in range(p_res):
        print(P1_range[i])
        # Initial entry and capital choices
        results, entrants = industry_entrants(P1_range[i, 0, :], g0, g1, g2, nu, rho, beta, FC, P1_range[i], dist_p1, om_res=om_res)

        # (fixed cost, convex cost) per regime: prohibitive, actual, none
        regimes = ((1e9, 1e9), (g0, g2), (0, 0))
        decisions = np.zeros((len(regimes),) + results.shape)
        choice_codes = np.zeros((len(regimes), len(entrants)))  # -1 exit, 1 adjusted, 2 unadjusted

        for plant in np.flatnonzero(entrants):  # Only plants that entered in the first place
            for r, (fixed_cost, convex_cost) in enumerate(regimes):
                decisions[r, plant, :], choice_codes[r, plant] = incumbent_profit(
                    results[plant, 1], prices_new, results[plant, 0],
                    fixed_cost, g1, convex_cost, nu, rho, beta)

        print('k_opt', P1_range[i, 0, :])
        print(choice_codes)
        print()

        for r in range(len(regimes)):
            extant = entrants & (decisions[r][:, 1] > 0)  # Entered and did not exit
            avg_intensity[i, r] = np.mean(decisions[r][extant, -2])

    # New entrants face the period-1 prices now and, with probability one,
    # next period, so the result does not depend on the grid point.
    res, entry = industry_entrants(prices_new, g0, g1, g2, nu, rho, beta, FC, prices_new[np.newaxis, :], np.array([1]), om_res=om_res)
    avg_intensity[:, 3] = np.mean(res[entry, -2])
    return avg_intensity, P1_range[:, 0, -1]




g0 = 0.08  # Fixed cost of changing the capital stock
g1 = 1  # Cost of capital
g2 = 0.02  # Convex adjustment cost

nu = 0.9  # Returns to scale
rho = -3  # Substitution parameter

FC = 5  # Fixed cost of entry

P = 0  # Probability firms place on the price change

if P == 0:
    title = 'unanticipated'
elif P == 1:
    title = 'perfect foresight'
else:
    title = 'P( price change ) = '+str(P)

dist_p1 = np.array([1-P, P])  # Firms' beliefs: weight 1-P on prices staying put, P on the change
p_res = 20  # Number of prices sampled
om_res = 500  # Number of energy productivity points sampled


betas = [0.6, 0.8]  # Discount factors

# Create the output directory up front so a missing folder fails (or is
# fixed) before the long simulation rather than at the final save.
os.makedirs('Saved_Data', exist_ok=True)
for beta in betas:
    avg_intensity, p_e = price_change_curve(g0, g1, g2, nu, rho, beta, FC, dist_p1, p_res, om_res)
    np.save('Saved_Data/'+title+'_beta='+str(round(100*beta)), avg_intensity)


